CUDA: use mma PTX instructions for FlashAttention #11583
Conversation
The CUDA 11.7 compilation problems on Windows seem to be because the …
@@ -1775,7 +1775,7 @@ extern "C" {
        struct ggml_tensor * a,
        int k);

#define GGML_KQ_MASK_PAD 32
Minor note for the future: we should hide this constant behind an API call.
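For illustration, a minimal sketch of what hiding the constant behind an API call could look like; the function name `ggml_flash_attn_kq_mask_pad` is hypothetical and not an existing ggml API:

```cuda
#include <stdint.h>

#define GGML_KQ_MASK_PAD 32

// Hypothetical accessor: callers query the required KQ mask padding at runtime
// instead of relying on the GGML_KQ_MASK_PAD compile-time constant directly.
int64_t ggml_flash_attn_kq_mask_pad(void) {
    return GGML_KQ_MASK_PAD;
}
```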
Can there be a PTX kernel for Pascal via dp4a?
For Pascal I am not aware of any useful PTX instructions that are not already being used in the CUDA code. As of right now, quantized KV caches are converted to floating-point numbers for large-batch FlashAttention. It would in principle be possible to write a FlashAttention kernel that uses int8 arithmetic instead. This would very likely be faster on Pascal; on Turing or newer it's a bit unclear, since it's difficult to get good int8 tensor core utilization with the GGML quantization formats.
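For reference, a minimal sketch of the kind of int8 arithmetic Pascal (SM 6.1 and newer) can accelerate without tensor cores, using the `__dp4a` intrinsic; the kernel is illustrative only and not code from this PR:

```cuda
#include <cuda_runtime.h>

// Dot product over packed int8 values: each int holds 4 int8 lanes and
// __dp4a multiplies the 4 lane pairs and adds them to the accumulator.
// Launch with a single block of 32 threads (one warp) for this sketch.
__global__ void dot_q8(const int * __restrict__ x, const int * __restrict__ y,
                       int * __restrict__ out, const int n_packed) {
    int sum = 0;
    for (int i = threadIdx.x; i < n_packed; i += blockDim.x) {
        sum = __dp4a(x[i], y[i], sum);
    }
    // Butterfly reduction within the warp.
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, 32);
    }
    if (threadIdx.x == 0) {
        *out = sum;
    }
}
```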
I pushed a version with a workaround for movmatrix.
While going through the PTX ISA documentation I noticed that there is a table that maps PTX versions to CUDA versions. Using that table I could determine that …
Co-authored-by: Diego Devesa <slarengh@gmail.com>
There are multiple issues with making this work on AMD MFMA (and let's forget about WMMA). For one, the shapes are quite non-ideal: all AMD shapes are square and 8 is not among them (4x4xN, 16x16xN, 32x32xN, with 4x4 of course being the least efficient). The output layout is also different (and has its own private register space we would have to move it out of). Maybe you can make it perform better than the vector implementation, but it won't be ideal at all. It's not something I will be trying to tackle soon, as there is lower-hanging fruit.
I wish we could keep the WMMA implementation, as it performs decently on MFMA/rocWMMA for large batch sizes.
@@ -25,6 +25,7 @@
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
The synchronous version of the shuffle operators has been supported since ROCm 6.2; probably we should just add the ability to conditionally use them based on HIP_VERSION >= 60200000, even if this won't do anything for now, just so it doesn't get forgotten in the future.
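A minimal sketch of what that conditional mapping could look like, assuming HIP_VERSION >= 60200000 is the right gate (the exact placement of the guard in the header is not specified here):

```cuda
#include <hip/hip_version.h>

// Only remap the *_sync shuffle builtins to the legacy HIP variants on ROCm
// versions that lack the synchronous forms; on ROCm >= 6.2 the native
// __shfl_sync/__shfl_xor_sync builtins are used as-is.
#if HIP_VERSION < 60200000
#define __shfl_sync(mask, var, laneMask, width)     __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#endif // HIP_VERSION < 60200000
```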
For CUDA the implementation is definitely bad and, in a vacuum, I would not want to invest the effort to maintain it long-term. However, I recognize that you're making valuable contributions to the project and I'm willing, in return, to maintain the WMMA code. I would also be willing to merge an optional compilation argument for rocWMMA if I get a pledge from you that you will maintain it (since I currently do not have any suitable hardware for testing; that may change in March). But I must stress that, relative to what the hardware would be capable of, the performance of a HIP port of my WMMA code will always be bad. Ideally we would just have a ROCm implementation. I would be willing to learn ROCm in order to review any related code.
Don't get me wrong, it's bad; it just performs better than the other attention implementations available in ggml for now. Let's discuss details at a later date. As far as I understand, it's in no immediate risk of being removed anyway, since it still benefits Volta as well.
@JohannesGaessler After this change, the … How should the ggml-ci script be modified to fix this?
* CUDA: use mma PTX instructions for FlashAttention
* __shfl_sync workaround for movmatrix
* add __shfl_sync to HIP

Co-authored-by: Diego Devesa <slarengh@gmail.com>
This PR replaces the WMMA-based CUDA FlashAttention kernel with a kernel that instead uses PTX instructions to access the tensor cores. These kernels are typically used for batch sizes >> 1, but also for batched inference. The principle of the new kernel is the same as with MMQ: in mmq.cuh there are primitives that expose otherwise inaccessible PTX instructions to CUDA code. The primitives are similar to the WMMA interface, but they have a well-defined data layout, which allows for better optimization. The data layout is the same for FP16 and int8, and the same from Turing onward. I replaced INT8_MMA_AVAILABLE with NEW_MMA_AVAILABLE to better reflect this. The tensor cores on Volta are not compatible with the new code; for V100s the old code is used. Long-term I plan to purchase a V100 and write a dedicated kernel. There was interest from @IMbackK regarding the use of AMD tensor cores; if they could be made to fit the interface in mmq.cuh they could in principle work, but to me this seems to be an unlikely prospect. (A rough sketch of such a primitive is given after the plots below.)

[Plot: t/s by batch size]

[Plot: t/s by prompt length]
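For illustration, a minimal sketch of the kind of primitive meant above: a device function that exposes one FP16 tensor core mma.sync PTX instruction (m16n8k8 shape, FP32 accumulator) to CUDA code. The function name and the chosen shape are for the example only and are not the actual interface from this PR:

```cuda
#include <cuda_fp16.h>

// Wrap the PTX mma.sync instruction (m16n8k8, FP16 inputs, FP32 accumulator)
// in a device function so regular CUDA code can call it. Per-thread register
// counts follow the PTX ISA: A = 2x b32, B = 1x b32, C and D = 4x f32.
static __device__ __forceinline__ void mma_m16n8k8_f16_f32(
        float * D, const half2 * A, const half2 * B, const float * C) {
#if __CUDA_ARCH__ >= 750 // Turing or newer
    const unsigned * Au = (const unsigned *) A;
    const unsigned * Bu = (const unsigned *) B;
    asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(Au[0]), "r"(Au[1]), "r"(Bu[0]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#else
    (void) D; (void) A; (void) B; (void) C; // no tensor cores on older arches
#endif
}
```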
Performance for large batch sizes is good; for small batch sizes it's still suboptimal. Asymptotically, the speedup for long prompts is about 1.1-1.5x on the GPUs I've tested. Notably, the new kernel uses a stream-k decomposition, so the performance should generalize better beyond the GPUs that I as a dev am optimizing performance for. Also, so far I've restricted the implementation to features that are available with Turing; there are some Ampere features that should be useful.
The file size of libggml-cuda.so with GGML_NATIVE=OFF decreases from 363 MB to 358 MB. There is no change in compilation time on my multithreaded system because the overall compilation is bottlenecked by waiting for MMQ.